{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Updatable Nearest Neighbor Classifier\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates the process of creating an updatable empty k-nearest neighbor model using coremltools."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from coremltools.models.nearest_neighbors import KNearestNeighborsClassifierBuilder\n",
    "\n",
    "# Model parameters, named once so the builder call and the description text cannot drift apart.\n",
    "number_of_dimensions = 128\n",
    "number_of_neighbors = 3\n",
    "\n",
    "# Create the k-nearest-neighbors classifier builder that the rest of the notebook inspects and adjusts.\n",
    "builder = KNearestNeighborsClassifierBuilder(input_name='input',\n",
    "                                             output_name='output',\n",
    "                                             number_of_dimensions=number_of_dimensions,\n",
    "                                             default_class_label='defaultLabel',\n",
    "                                             number_of_neighbors=number_of_neighbors,\n",
    "                                             weighting_scheme='inverse_distance',\n",
    "                                             index_type='linear')\n",
    "\n",
    "# Model-level metadata.\n",
    "builder.author = 'Core ML Tools Example'\n",
    "builder.license = 'MIT'\n",
    "builder.description = 'Classifies {} dimension vector based on {} nearest neighbors'.format(\n",
    "    number_of_dimensions, number_of_neighbors)\n",
    "\n",
    "# Short descriptions for the prediction input and the two outputs (label, then per-label scores).\n",
    "builder.spec.description.input[0].shortDescription = 'Input vector to classify'\n",
    "builder.spec.description.output[0].shortDescription = 'Predicted label. Defaults to \\'defaultLabel\\''\n",
    "builder.spec.description.output[1].shortDescription = 'Probabilities / score for each possible label.'\n",
    "\n",
    "# Short descriptions for the training inputs used when the model is updated (example vector, then its label).\n",
    "builder.spec.description.trainingInput[0].shortDescription = 'Example input vector'\n",
    "builder.spec.description.trainingInput[1].shortDescription = 'Associated true label of each example vector'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# By default, an empty k-NN model is updatable; display the flag to confirm.\n",
    "builder.is_updatable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "128"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's confirm the number of dimensions is set correctly (it should match the value passed to the builder).\n",
    "builder.number_of_dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's check the current value of 'numberOfNeighbors' (the builder was created with 3).\n",
    "builder.number_of_neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 1000)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The number of neighbors is bounded by an allowed range; display the default bounds.\n",
    "builder.number_of_neighbors_allowed_range()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ValueError: number_of_neighbors is not within range bounds\n"
     ]
    }
   ],
   "source": [
    "# Setting the number of neighbors to a value outside of the allowed range raises a ValueError.\n",
    "# The error is caught and printed so the notebook still runs top to bottom (Restart & Run All).\n",
    "try:\n",
    "    builder.number_of_neighbors = 1001\n",
    "except ValueError as error:\n",
    "    print('ValueError: {}'.format(error))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instead of a range, you can set individual values that are valid for the numberOfNeighbors parameter.\n",
    "# The first argument is the number of neighbors to use; allowed_set lists the permitted values.\n",
    "builder.set_number_of_neighbors_with_bounds(3, allowed_set={ 1, 3, 5 })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1, 3, 5}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the set of allowed values for numberOfNeighbors configured by the previous cell.\n",
    "builder.number_of_neighbors_allowed_set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ValueError: number_of_neighbors is not valid\n"
     ]
    }
   ],
   "source": [
    "# And now if you attempt to set it to a value that is not in the allowed set, a ValueError is raised.\n",
    "# The error is caught and printed so the notebook still runs top to bottom (Restart & Run All).\n",
    "try:\n",
    "    builder.number_of_neighbors = 4\n",
    "except ValueError as error:\n",
    "    print('ValueError: {}'.format(error))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# And of course you can go back to a range constraint: set the value to 3 with allowed range (1, 30).\n",
    "builder.set_number_of_neighbors_with_bounds(3, allowed_range=(1, 30))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'linear'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's see what the index type is ('linear' was passed when the builder was created).\n",
    "builder.index_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'kd_tree'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's set the index to kd_tree with a leaf size of 30, then display the index type to confirm the change.\n",
    "builder.set_index_type('kd_tree', 30)\n",
    "builder.index_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from coremltools.models import MLModel\n",
    "\n",
    "# Destination file for the saved model.\n",
    "mlmodel_updatable_path = './UpdatableKNN.mlmodel'\n",
    "\n",
    "# Wrap the builder's spec in an MLModel and save it to the path above.\n",
    "mlmodel_updatable = MLModel(builder.spec)\n",
    "mlmodel_updatable.save(mlmodel_updatable_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
